import pyopencl as cl
import numpy as np
import math

# -------------------------------
# Constants
# -------------------------------
# Lattice geometry: D_TOTAL is the number of lattice cells; each slot map
# built below holds INSTANCES * SLOTS_PER entries.
D_TOTAL = 4096
SLOTS_PER = 4
INSTANCES = 8

# Golden-ratio constant passed to the kernel, which thresholds on sqrt(phi).
PHI = 1.6180339887

# -------------------------------
# Setup OpenCL Context
# -------------------------------
# An OpenCL context can only hold devices that belong to one platform, so use
# the GPUs of the first platform that exposes any instead of pooling the GPUs
# of every platform into a single context.
platforms = cl.get_platforms()
gpu_devices = next(
    (
        devices
        for devices in (p.get_devices(cl.device_type.GPU) for p in platforms)
        if devices
    ),
    [],
)
if not gpu_devices:
    # Fail with an actionable message rather than an opaque OpenCL error
    # raised while creating a context from an empty device list.
    raise RuntimeError("No OpenCL GPU device found on any platform")
ctx = cl.Context(devices=gpu_devices)
# No device is given, so pyopencl chooses one of the context's devices.
queue = cl.CommandQueue(ctx)

# -------------------------------
# Initialize lattice and slot arrays
# -------------------------------
# Host-side lattice: D_TOTAL float32 cells, all starting at zero.
lattice_host = np.zeros((D_TOTAL,), dtype=np.float32)

# Example slot mappings: each is a contiguous run of int32 lattice indices.
_slot_count = INSTANCES * SLOTS_PER
slots_control = np.arange(_slot_count, dtype=np.int32)
slots_workspace = np.arange(1024, 1024 + _slot_count, dtype=np.int32)
slots_console = np.arange(24, 32, dtype=np.int32)

# OpenCL buffers, each initialised from its host array at creation time.
# The slot maps are read-only on the device; the lattice is read-write.
mf = cl.mem_flags
_read_only_init = mf.READ_ONLY | mf.COPY_HOST_PTR
lattice_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=lattice_host)
slots_control_buf = cl.Buffer(ctx, _read_only_init, hostbuf=slots_control)
slots_workspace_buf = cl.Buffer(ctx, _read_only_init, hostbuf=slots_workspace)

# -------------------------------
# OpenCL Kernel (AMD-compatible)
# -------------------------------
# One work-item per slot position `gid`; each launch runs three phases:
#   1. wave interference on control slot `gid` (neighbour difference),
#   2. strand blending from control slot `gid` into workspace slot `gid`,
#   3. threshold projection of workspace slot `gid` to 0.0 / 1.0.
# Phase 1 reads neighbouring cells that other work-items update during the
# same launch, so the neighbour difference is gathered into a private
# variable first and written back only after a work-group barrier.  The
# barrier removes that read/write race within a work-group; it does not
# synchronise across work-groups, so results are fully deterministic only
# when all slots run in a single work-group.
kernel_source = """
__kernel void lattice_evolve(__global float *lattice,
                             const int D_TOTAL,
                             __global const int *slots_control,
                             const int n_control,
                             __global const int *slots_workspace,
                             const int n_workspace,
                             const float phi) {
    const int gid = get_global_id(0);
    const int is_control = gid < n_control;
    const int is_workspace = gid < n_workspace;
    const float threshold = sqrt(phi);

    // Wave interference for control slots: gather the neighbour difference
    // first; the write happens after the barrier below.
    int c_idx = 0;
    float delta = 0.0f;
    if (is_control) {
        c_idx = slots_control[gid];
        int left = (c_idx == 0) ? c_idx : c_idx - 1;
        int right = (c_idx == D_TOTAL - 1) ? c_idx : c_idx + 1;
        delta = 0.5f * (lattice[left] - lattice[right]);
    }

    // Every work-item reaches this barrier (it sits outside any branch), so
    // all neighbour reads in the work-group finish before any cell changes.
    barrier(CLK_GLOBAL_MEM_FENCE);

    if (is_control) {
        lattice[c_idx] += delta;
    }

    // Strand blending: control -> workspace
    if (is_control && is_workspace) {
        int w_idx = slots_workspace[gid];
        if (lattice[c_idx] > threshold) {
            lattice[w_idx] += 1.0f;
        }
    }

    // Threshold projection for workspace
    if (is_workspace) {
        int w_idx = slots_workspace[gid];
        lattice[w_idx] = (lattice[w_idx] >= threshold) ? 1.0f : 0.0f;
    }
}
"""

# Compile the kernel source for the devices of the context.
program = cl.Program(ctx, kernel_source).build()

# -------------------------------
# Inject random snapshots into workspace
# -------------------------------
# One random float32 value per workspace slot.
kernel_snapshot = np.random.rand(len(slots_workspace)).astype(np.float32)
# Upload the host lattice, then overwrite the workspace region on the device.
cl.enqueue_copy(queue, lattice_buf, lattice_host)
# Byte offset of the first workspace cell, derived from the slot map and the
# lattice element size instead of repeating the literals 1024 and 4.  This
# relies on the workspace slots forming one contiguous ascending run.
workspace_offset_bytes = int(slots_workspace[0]) * lattice_host.itemsize
cl.enqueue_copy(queue, lattice_buf, kernel_snapshot, device_offset=workspace_offset_bytes)

# -------------------------------
# Evolution Loop
# -------------------------------
EVOLUTION_TICKS = 100

# Retrieve the kernel once instead of looking it up on `program` every tick
# (pyopencl advises against repeated attribute retrieval of kernels).
evolve_kernel = program.lattice_evolve
# Launch size and scalar arguments are the same for every tick.
global_size = (len(slots_workspace),)
d_total_arg = np.int32(D_TOTAL)
n_control_arg = np.int32(len(slots_control))
n_workspace_arg = np.int32(len(slots_workspace))
phi_arg = np.float32(PHI)

for tick in range(EVOLUTION_TICKS):
    evolve_kernel(
        queue, global_size, None,
        lattice_buf,
        d_total_arg,
        slots_control_buf, n_control_arg,
        slots_workspace_buf, n_workspace_arg,
        phi_arg,
    )

    # Read back console every 20 ticks
    if tick % 20 == 0:
        cl.enqueue_copy(queue, lattice_host, lattice_buf)
        # '#' marks a console cell holding a positive value, '.' anything else.
        console_out = ''.join('#' if cell > 0 else '.' for cell in lattice_host[slots_console])
        print(f"[Tick {tick}] Console: {console_out}")

# -------------------------------
# Final Lattice Snapshot
# -------------------------------
cl.enqueue_copy(queue, lattice_host, lattice_buf)
print("HDGL-native Debian Bootstrap Complete (OpenCL)")
print("Control + first 16 workspace slots:")
# Gather through the slot maps: the workspace slots are not adjacent to the
# control slots in the lattice, so a plain prefix slice would print the cells
# that follow the control region instead of workspace values.
print(np.concatenate((lattice_host[slots_control], lattice_host[slots_workspace[:16]])))
